from typing import List

import numpy as np
import torch
from rsl_rl.algorithms import PPO

import internutopia.core.util.gym as gymutil
import internutopia.core.util.math as math_utils
from internutopia.core.robot.articulation_action import ArticulationAction
from internutopia.core.robot.articulation_subset import ArticulationSubset
from internutopia.core.robot.controller import BaseController
from internutopia.core.robot.robot import BaseRobot
from internutopia.core.scene.scene import IScene
from internutopia_extension.configs.controllers import AliengoMoveBySpeedControllerCfg
from internutopia_extension.controllers.models.aliengo.actor_critic import ActorCritic


class RLPolicy:
    """Wrapper that restores a trained rsl-rl PPO checkpoint and exposes the
    actor network as an inference-only policy callable.

    Args:
        path: Filesystem path to a checkpoint produced by rsl-rl
            (a ``torch.save``-d dict with ``model_state_dict`` etc.).
    """

    def __init__(self, path: str) -> None:
        # Network architecture; must match the configuration used at training
        # time, otherwise load_state_dict() below will fail.
        self.policy_cfg = {
            'class_name': 'ActorCritic',
            'init_noise_std': 1.0,
            'actor_hidden_dims': [512, 256, 128],
            'critic_hidden_dims': [512, 256, 128],
            'activation': 'elu',
        }
        self.empirical_normalization = False

        self.load(path=path)

    def load(self, path: str, load_optimizer: bool = False):
        """Build the actor-critic network and restore its weights from *path*.

        Args:
            path: Checkpoint file to load.
            load_optimizer: Also restore optimizer state; only relevant when
                resuming training, not for inference.

        Returns:
            The ``infos`` entry stored in the checkpoint.
        """
        # Observation/action sizes fixed by the training setup:
        # 270 stacked obs = 45-dim one-step obs x 6 steps; 12 = one action
        # per joint of the quadruped.
        num_obs = 270
        num_critic_obs = 238
        one_step_obs = 45
        env_actions = 12
        self.actor_critic = ActorCritic(num_obs, num_critic_obs, one_step_obs, env_actions, **self.policy_cfg)
        # rsl-rl's PPO moves the network to the requested device; fall back to
        # CPU so construction does not crash on machines without CUDA (the
        # controller below runs inference on CPU anyway).
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.alg = PPO(self.actor_critic, device=device)

        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # machines; the network is moved to the caller-requested device later
        # in get_inference_policy().
        loaded_dict = torch.load(path, map_location='cpu')
        self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
        if load_optimizer:
            self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
            self.alg.actor_critic.estimator.optimizer.load_state_dict(loaded_dict['estimator_optimizer_state_dict'])
        self.current_learning_iteration = loaded_dict['iter']
        return loaded_dict['infos']

    def get_inference_policy(self, device=None):
        """Return the actor's deterministic inference callable.

        Args:
            device: Optional torch device to move the network to first.
        """
        self.eval_mode()  # switch to evaluation mode (dropout for example)
        if device is not None:
            self.actor_critic.to(device)
        return self.actor_critic.act_inference

    def eval_mode(self):
        """Put the network (and obs normalizers, when enabled) in eval mode."""
        self.actor_critic.eval()
        # NOTE(review): obs_normalizer/critic_obs_normalizer are never created
        # in this class; this branch is unreachable while
        # empirical_normalization stays False.
        if self.empirical_normalization:
            self.obs_normalizer.eval()
            self.critic_obs_normalizer.eval()

@BaseController.register('AliengoMoveBySpeedController')
class AliengoMoveBySpeedController(BaseController):
    """Controller class converting locomotion speed control action to joint positions for Aliengo robot."""

    """
    joint_names_sim and joint_names_gym define default joint orders in isaac-sim and isaac-gym.
    """

    # isaac-sim joint order: grouped by joint type (hips, thighs, calves).
    joint_names_sim = [
        'FL_hip_joint',
        'FR_hip_joint',
        'RL_hip_joint',
        'RR_hip_joint',
        'FL_thigh_joint',
        'FR_thigh_joint',
        'RL_thigh_joint',
        'RR_thigh_joint',
        'FL_calf_joint',
        'FR_calf_joint',
        'RL_calf_joint',
        'RR_calf_joint',
    ]

    # isaac-gym joint order: grouped by leg. The policy operates in this
    # order, so observations/actions are converted through gym_adapter.
    joint_names_gym = [
        'FL_hip_joint',
        'FL_thigh_joint',
        'FL_calf_joint',
        'FR_hip_joint',
        'FR_thigh_joint',
        'FR_calf_joint',
        'RL_hip_joint',
        'RL_thigh_joint',
        'RL_calf_joint',
        'RR_hip_joint',
        'RR_thigh_joint',
        'RR_calf_joint',
    ]

    def __init__(self, config: AliengoMoveBySpeedControllerCfg, robot: BaseRobot, scene: IScene) -> None:
        """Load the RL policy and set up joint-order conversion and observation history.

        Args:
            config: Controller config providing ``policy_weights_path`` and an
                optional ``joint_names`` list selecting a subset of the
                robot's joints.
            robot: Robot this controller drives.
            scene: Scene the robot belongs to.
        """
        super().__init__(config=config, robot=robot, scene=scene)

        # Inference runs on CPU.
        self._policy = RLPolicy(path=config.policy_weights_path).get_inference_policy(device='cpu')
        # Converts 12-vectors between isaac-gym order (policy side) and
        # isaac-sim order (robot side).
        self.gym_adapter = gymutil.gym_adapter(self.joint_names_gym, self.joint_names_sim)
        self.joint_subset = None
        self.joint_names = config.joint_names
        if self.joint_names is not None:
            self.joint_subset = ArticulationSubset(self.robot.articulation, self.joint_names)
        # Last raw policy action (sim joint order), fed back to the policy as
        # the "previous action" part of the next observation.
        self._old_joint_positions = np.zeros(12)
        # Stacked observation size: 45-dim one-step obs x 6 steps.
        self.policy_input_obs_num = 270
        self._old_policy_obs = np.zeros(self.policy_input_obs_num)
        self._apply_times_left = (
            0  # Specifies how many times the action generated by the policy needs to be repeatedly applied.
        )

    def forward(
        self,
        forward_speed: float = 0,
        rotation_speed: float = 0,
        lateral_speed: float = 0,
    ) -> ArticulationAction:
        """Compute joint position targets tracking the commanded base velocity.

        The policy is queried only every 4th call; on the 3 calls in between,
        the previously computed joint targets are re-applied (action repeat).

        Args:
            forward_speed: Commanded forward (x) speed.
            rotation_speed: Commanded yaw rate.
            lateral_speed: Commanded lateral (y) speed.

        Returns:
            ArticulationAction with joint position targets, scoped to the
            configured joint subset when one is set, otherwise to the full
            articulation.
        """
        # Re-apply the last computed action while repeats remain.
        if self._apply_times_left > 0:
            self._apply_times_left -= 1
            if self.joint_subset is None:
                return ArticulationAction(joint_positions=self.applied_joint_positions)
            return self.joint_subset.make_articulation_action(
                joint_positions=self.applied_joint_positions, joint_velocities=None
            )

        # Get obs for policy.
        robot_base = self.robot.get_robot_base()
        base_pose_w = robot_base.get_pose()
        base_quat_w = torch.tensor(base_pose_w[1]).reshape(1, -1)
        base_ang_vel_w = torch.tensor(robot_base.get_angular_velocity()[:]).reshape(1, -1)
        # Base angular velocity expressed in the base frame.
        base_ang_vel = np.array(math_utils.quat_rotate_inverse(base_quat_w, base_ang_vel_w).reshape(-1))

        # World gravity direction rotated into the base frame (encodes tilt).
        projected_gravity = torch.tensor([[0.0, 0.0, -1.0]], device='cpu', dtype=torch.float)
        projected_gravity = np.array(math_utils.quat_rotate_inverse(base_quat_w, projected_gravity).reshape(-1))
        joint_pos = (
            self.joint_subset.get_joint_positions()
            if self.joint_subset is not None
            else self.robot.articulation.get_joint_positions()
        )
        joint_vel = (
            self.joint_subset.get_joint_velocities()
            if self.joint_subset is not None
            else self.robot.articulation.get_joint_velocities()
        )
        # Nominal standing pose (sim joint order): hips 0, thighs 0.8,
        # calves -1.5 rad. NOTE(review): assumed to match the training-time
        # default_dof_pos — confirm against the training config.
        default_dof_pos = np.array(
            [0.0, 0.0, 0.0, 0.0, 0.8000, 0.8000, 0.8000, 0.8000, -1.5000, -1.5000, -1.5000, -1.5000]
        )

        # Policy observes joint positions relative to the nominal pose.
        joint_pos -= default_dof_pos

        # Set action command.
        tracking_command = np.array([forward_speed, lateral_speed, rotation_speed], dtype=np.float32)

        # One-step observation (45 dims = 3 command + 3 ang vel + 3 gravity +
        # 12 joint pos + 12 joint vel + 12 previous action). The scale factors
        # (2.0/0.25/0.05) presumably mirror training-time observation
        # scaling — TODO confirm.
        raw_policy_obs = np.concatenate(
            [
                tracking_command * np.array([2.0, 2.0, 0.25]),
                base_ang_vel * 0.25,
                projected_gravity,
                self.gym_adapter.sim2gym(joint_pos),
                self.gym_adapter.sim2gym(joint_vel * 0.05),
                self.gym_adapter.sim2gym(self._old_joint_positions.reshape(12)),
            ]
        )
        # Stack with history: newest 45 dims first, then the previous 225
        # (the oldest 45 are dropped), giving the 270-dim policy input.
        policy_obs = np.concatenate([raw_policy_obs, self._old_policy_obs[:-45]])
        self._old_policy_obs = policy_obs
        policy_obs = policy_obs.reshape(1, 270)

        # Infer with policy.
        with torch.inference_mode():
            # Network output scaled by 0.5 gives joint offsets in gym order.
            joint_positions: np.ndarray = (
                self._policy(torch.tensor(policy_obs, dtype=torch.float32).to('cpu')).detach().numpy() * 0.5
            )
            joint_positions = joint_positions[0]
            joint_positions = self.gym_adapter.gym2sim(joint_positions)
            # Store the unscaled action (x2 undoes the 0.5 above) as the
            # "previous action" observation for the next policy query.
            self._old_joint_positions = joint_positions * 2
            # Absolute targets = policy offsets around the nominal pose.
            self.applied_joint_positions = joint_positions + default_dof_pos
            # Repeat this action for the next 3 control steps.
            self._apply_times_left = 3

        if self.joint_subset is None:
            return ArticulationAction(joint_positions=self.applied_joint_positions)
        return self.joint_subset.make_articulation_action(
            joint_positions=self.applied_joint_positions, joint_velocities=None
        )

    def action_to_control(self, action: List | np.ndarray) -> ArticulationAction:
        """
        Args:
            action (List | np.ndarray): 3-element 1d array containing:
              0. forward_speed (float)
              1. lateral_speed (float)
              2. rotation_speed (float)

        Returns:
            ArticulationAction: joint positions to apply.
        """
        assert len(action) == 3, 'action must contain 3 elements'
        return self.forward(
            forward_speed=action[0],
            lateral_speed=action[1],
            rotation_speed=action[2],
        )
